# model with no Xs
# code originally from estimate_xs.R as of 06/12/2023

#### Set-Up ####

# Start from an empty global environment.
# NOTE(review): this also deletes anything the caller defined before running
# the script; kept because the rest of the script may rely on a clean workspace.
rm(list = ls())
Sys.info()["user"] # user running the script (shown before the log is opened)

# Set file paths (all relative, so the script must be launched from the code
# folder)
DATA <- "../external_links/real"
TEMP <- "../temp"
RESULTS <- "../output"

# Go to code folder (as written this re-selects the current directory)
setwd("./")

# From here on, console output is diverted to the log file; the matching
# sink() at the end of the script closes it.
# NOTE(review): the log is named "estimate_xs.txt" although this script is the
# no-Xs variant derived from estimate_xs.R; if both scripts run from the same
# folder they overwrite each other's log -- confirm this is intended.
sink("estimate_xs.txt")

# Make sure every output folder exists. `recursive = TRUE` also creates a
# missing parent (TEMP / RESULTS), and dir.exists() is not fooled by a regular
# file that happens to have the same name.
for (folder in c(
  paste0(TEMP, "/estimate_w_xs"),
  paste0(RESULTS, "/estimate_w_xs"),
  paste0(RESULTS, "/estimate_w_xs/figures")
)) {
  if (dir.exists(folder)) {
    cat("The folder already exists")
  } else {
    dir.create(folder, recursive = TRUE)
  }
}

# Record who ran the script and when in the log
Sys.info()["user"]
Sys.time()

# Packages
library(MASS)
library(tmvtnorm)
library(haven)
library(dplyr)
library(tidyr)
library(nloptr)
library(corpcor)
library(stargazer)
library(pracma)
library(foreign)
library(writexl)
library(broom)
library(png)
library(tidyverse)

#### Load Data + Assumptions ####
Sys.info()["user"]

# Read the estimation sample and derive the indicator covariates used below.
# analysis <- read_dta(paste0(TEMP, "/estimate_w_xs/model_sample_recon.dta")) %>%
analysis <- read_dta(paste0(DATA, "/model_sample.dta")) %>%
  mutate(
    # A missing out-of-pocket cost is recoded to Inf
    oop_i = ifelse(is.na(oop_i), Inf, oop_i),
    age_25_34 = ifelse(age >= 25 & age < 35, 1, 0),
    # ">= 35" so that age 35 belongs to this group: age_25_34 stops at
    # age < 35, and the previous "age > 35" left age == 35 in neither group.
    age_35_plus = ifelse(age >= 35, 1, 0),
    # Missing education is coded 0 here and flagged by missing_college
    some_college = ifelse(educ == 1 | is.na(educ), 0, 1),
    missing_college = ifelse(is.na(educ), 1, 0),
    # NOTE(review): if none of the three flags equals 1 and at least one is
    # NA, this evaluates to NA -- confirm the inputs are never missing.
    any_prev_birth_issue = ifelse(
      prev_concern == 1 | prev_any_q_icd == 1 | prev_mis_still == 1, 1, 0
    ),
    # Missing income quartile / education are coded 0
    inc_quartile_4 = ifelse(inc_quartile != 4 | is.na(inc_quartile), 0, 1),
    full_college = ifelse(educ != 3 | is.na(educ), 0, 1)
  )

print(DATA)

# Scalar assumptions. They are not referenced again in this script, so they are
# presumably read as globals by the sourced model functions -- verify.
pa <- 0.005
pfp <- 0.01
pfn <- 0.01

J <- 1
#file.remove(paste0(TEMP, "/Prenatal/estimate_w_xs/estimates.csv"))
#file.remove(paste0(RESULTS, "/Prenatal/estimate_w_xs/estimates.csv"))


#### Load Necessary Model Functions ####

# moments_wgts(), obj() and model_decisions() are not defined in this script;
# presumably they are provided by the file sourced here -- verify.
source("r_functions/production_model_xs_new.R")

#### ESTIMATE ####
print("Estimate on actual data")

# # to load moments:
# actual <- read_dta(paste0(DATA, "/obs_moments_export.dta"))
# print("loaded actual moments")
#   data_moments <- actual[[2]]
#   weights <- actual[[3]]
#   rm(actual)

# to calculate moments:
# The first element of the moments_wgts() result is used as the data moments,
# the second as their weights.
actual <- moments_wgts(data = analysis)
data_moments <- actual[[1]]
weights <- actual[[2]]
rm(actual)
print("calculated actual moments")

# Save the empirical moments side by side with their weights
data_moments_table <- data.frame(data_moments, weights)
write.csv(
  data_moments_table,
  file = paste0(RESULTS, "/estimate_w_xs/data_moments_table.csv"),
  na = ".",
  row.names = TRUE
)

# print("table(analysis$did_nipt, analysis$wave)")
# table(analysis$did_nipt, analysis$wave)

# Objective as a function of the parameter vector only. f_a, f_c and x_vars are
# looked up when target() is called, so they only need to exist by then.
target <- function(x) {
  obj(
    x, analysis, data_moments, weights,
    f_a = f_a, f_c = f_c, x_vars = x_vars,
    unique_flag = 1, info = FALSE
  )
}

# Gradient of the objective, obtained by applying nl.grad() to target
g_target <- function(x) {
  nl.grad(x, target)
}

# Bounds when using the mean shifter!
# Whenever a parameter is added, add its lower and upper bound as well.
# For categorical variables, check the number of distinct values and add that
# many elements to the bounds.
# Order: (mu_a, mu_c, se_a, se_c, rho, psi, coef of Xs)

# No-Xs model: lower and upper bounds of every X coefficient are both 0, which
# pins those coefficients at 0.
lb <- c(-300000, -300000, 1, 1, -0.9, 0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
ub <- c(0, 0, 100000, 100000, 0.9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)

# Starting values, in the same order as the bounds
par0 <- c(-126772.04, -201776.45, 1, 61968.99,  0.8795, 0.9569, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)

# 1 x 1 character matrix, appended to estimates.csv as a separator row
divider <- t("####################")

# Wrapper function that runs a global method first, followed by a local method
# started from the global solution.
#
# Arguments:
#   x0         starting values for the global stage
#   eval_f     objective function of the parameter vector
#   eval_grad  gradient of eval_f; only forwarded to nloptr for gradient-based
#              algorithms (NLOPT_LD_* / NLOPT_GD_*), ignored otherwise
#   lb, ub     lower / upper bounds on the parameters
#   alg1       NLopt algorithm for the global stage
#   alg2       NLopt algorithm for the local stage
# alg1 and alg2 may be any algorithm string supported by NLOPTR.
#
# Returns list(global_result, local_result), the two nloptr() results.
#
# Side effects: resets the R seed with set.seed(101), prints both results and
# appends a separator row to estimates.csv under TEMP and RESULTS before each
# stage. Uses the globals `divider`, `TEMP` and `RESULTS`.
# normally, o1 maxtime == 2419200
doublesolve <- function(
  x0,
  eval_f,
  eval_grad=NA,
  lb,
  ub,
  alg1="NLOPT_GN_DIRECT_NOSCAL",
  alg2="NLOPT_LN_SBPLX"
) {
  # Gradient to hand to nloptr for `alg`: NULL for derivative-free algorithms
  # (so nloptr is never given a gradient it would not use), eval_grad for
  # gradient-based ones.
  grad_for <- function(alg) {
    if (!grepl("^NLOPT_[GL]D_", alg)) {
      return(NULL)
    }
    if (!is.function(eval_grad)) {
      stop(
        "Algorithm ", alg, " is gradient-based: pass a function as eval_grad.",
        call. = FALSE
      )
    }
    eval_grad
  }

  # Append the separator row to both copies of estimates.csv
  log_divider <- function() {
    for (root in c(TEMP, RESULTS)) {
      write.table(
        divider,
        file = paste0(root, "/estimate_w_xs/estimates.csv"),
        sep = ",", na = ".", append = TRUE,
        col.names = FALSE, row.names = FALSE
      )
    }
  }

  # Resolve both gradients up front so that a missing gradient is reported
  # before the (long) global stage starts, not after it.
  grad1 <- grad_for(alg1)
  grad2 <- grad_for(alg2)

  o1 <- list(
    "maxeval" = 10000, # used to be maxtime 3600 maxeval 10000000
    "xtol_rel" = 0,
    "algorithm" = alg1,
    # nloptr's option for the population size is named `population` (the former
    # `pop.size` is not among the nloptr option names). Only population-based
    # algorithms use it. If convergence is too slow, the 100 can be lowered to
    # ~10ish.
    "population" = (length(x0) + 1) * 100,
    "print_level" = 3,
    "ranseed" = 4792
  )

  o2 <- list(
    "maxeval" = 100000,
    "xtol_rel" = 1.0e-8,
    "algorithm" = alg2,
    "print_level" = 3,
    "ranseed" = 2837
  )

  # Fix R's RNG state before the objective is evaluated
  set.seed(101)

  log_divider()
  s1 <- nloptr(x0 = x0, eval_f = eval_f, eval_grad_f = grad1, lb = lb, ub = ub, opts = o1)
  print("Results of Global optimization")
  print(s1)

  log_divider()
  s2 <- nloptr(x0 = s1$solution, eval_f = eval_f, eval_grad_f = grad2, lb = lb, ub = ub, opts = o2)
  print("Results of Local optimization")
  print(s2)

  list(s1, s2)
}

####### Normalize X's variables #############
# Each covariate is centered on its sample mean (only the mean is subtracted;
# nothing is rescaled). mean() is called without na.rm, so a missing value in a
# covariate turns its whole centered column into NA.
analysis <- mutate(
  analysis,
  full_college_norm = full_college - mean(full_college),
  fc_n = full_college - mean(full_college), # same values as full_college_norm
  inc_quartile_4_norm = inc_quartile_4 - mean(inc_quartile_4),
  married_norm = married - mean(married),
  mom_foreign_norm = mom_foreign - mean(mom_foreign),
  any_prev_birth_issue_norm = any_prev_birth_issue - mean(any_prev_birth_issue),
  prev_kids_norm = dv_prev_kids - mean(dv_prev_kids),
  age_35_plus_norm = age_35_plus - mean(age_35_plus)
)

# Intercept-free formulas passed to the model functions as f_a / f_c; both use
# the same seven centered covariates.
f_a <- ~ 0 + full_college_norm + inc_quartile_4_norm + married_norm + mom_foreign_norm + any_prev_birth_issue_norm + age_35_plus_norm + prev_kids_norm
f_c <- ~ 0 + full_college_norm + inc_quartile_4_norm + married_norm + mom_foreign_norm + any_prev_birth_issue_norm + age_35_plus_norm + prev_kids_norm

# The same covariates as a character vector, in the order of the formula terms
x_vars <- c(
  "full_college_norm",
  "inc_quartile_4_norm",
  "married_norm",
  "mom_foreign_norm",
  "any_prev_birth_issue_norm",
  "age_35_plus_norm",
  "prev_kids_norm"
)

###### Optimize ####
# To use parameters inclusive of mean shifting:
# every covariate that enters f_a or f_c needs its own coefficient appended to
# the end of the parameter vector (par0) and of the bounds (lb, ub). For
# example, adding age to both f_a and f_c means TWO new parameters: one
# coefficient on age in f_a and one in f_c.

# Global stage followed by local stage; the elapsed time goes to the log
system.time(estimate <- doublesolve(par0, target, g_target, lb, ub))

# estimate[[1]] is the global-stage result: export its solution, one value per
# row and without a header.
global_est <- data.frame(global_est = estimate[[1]]$solution)
write.table(
  global_est,
  file = paste0(RESULTS, "/estimate_w_xs/estimates_global_no_xs.csv"),
  sep = ",",
  col.names = FALSE,
  row.names = FALSE
)

# Extract results
# estimate[[2]] is the local-stage result. The exported column stacks the
# solution, then the objective value, then the iteration count.
est <- estimate[[2]]
opt_est <- est$solution
results_1 <- data.frame(results_1 = c(opt_est, est$objective, est$iterations))
write_csv(results_1, file = paste0(RESULTS, "/estimate_w_xs/results_no_xs.csv"))
rm(estimate)

# Extract some diagnostics from local
# Re-evaluate the objective at the local solution with info = TRUE; the result
# supplies the moment names and the model / data moments used further down.
set.seed(4793285)
ans_local <- obj(
  est$solution, analysis, data_moments, weights,
  f_a = f_a, f_c = f_c, x_vars = x_vars,
  info = TRUE
)

# draw <- draw_a_c(par=est$solution, data=analysis, f_a=f_a, f_c=f_c, unique_flag=1, x_vars=x_vars)
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/histogram_a.pdf"))
# hist(draw$a_i, main=NULL, xlab="a_i", col="darkred", freq=FALSE)
# dev.off()
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/histogram_c.pdf"))
# hist(draw$c_i, main=NULL, xlab="c_i", col="blue4", freq=FALSE)
# dev.off()
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/histogram_a_trim.pdf"))
# hist(draw$a_i, xlim=c(-500000, 0), main=NULL, xlab="a_i", col="darkred", freq=FALSE)
# dev.off()
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/histogram_c_trim.pdf"))
# hist(draw$c_i, xlim=c(-500000, 0), main=NULL, xlab="c_i", col="blue4", freq=FALSE)
# dev.off()
# draw <- data.frame(draw) %>%
#   select(a_i, c_i) %>%
#   gather(key, value) # use tidyr::gather to convert from wide to long format
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/density_a_c.pdf"))
# ggplot(draw, aes(value, colour = key)) +
#   geom_density(show.legend = T) +
#   theme_minimal() +
#   scale_color_manual(values = c(a_i = "darkred", c_i = "blue4"))
# dev.off()
# pdf(file=paste0(RESULTS, "/estimate_w_xs/figures/density_a_c_trim.pdf"))
# ggplot(draw, aes(value, colour = key)) +
#   geom_density(show.legend = T) +
#   theme_minimal() +
#   xlim(-500000, 0) +
#   scale_color_manual(values = c(a_i = "darkred", c_i = "blue4"))
# dev.off()

#####################################
# Construct local opt moment comparison table #
#####################################
# One row per moment: its name, the model-implied value at the local optimum,
# the value in the data, and the weight.
moments_local <- data.frame(
  name = ans_local$moment_names,
  model_moments = ans_local$model,
  data_moments = ans_local$data,
  weight = weights
)

# Write the analysis table to disk
write.csv(
  moments_local,
  file = paste0(RESULTS, "/estimate_w_xs/moments_local_opt_global2.csv"),
  na = ".",
  row.names = TRUE
)

### TEMPORARILY CHANGE EST_PAR DEFINITION
# par_est (the local-stage solution) is the parameter vector used from here on
opt_est <- est$solution
par_est <- opt_est

# Save the same parameter file under both TEMP and RESULTS
invisible(lapply(c(TEMP, RESULTS), function(root) {
  write.csv(
    par_est,
    file = paste0(root, "/estimate_w_xs/nipt_estimates_rec.csv"),
    na = ".",
    row.names = TRUE
  )
}))

# ---- Model predictions at the estimated parameters ----
# Keep only the columns handed to model_decisions(), plus the covariates.
set.seed(101)
actual <- analysis %>%
  dplyr::select(pregnancy, bin_number, p_i, fetus_risk, wave, did_nipt, did_invasive, oop_i, policy_id, policy_regime, all_of(x_vars))
simulated <- model_decisions(par = par_est, data = actual, f_a = f_a, f_c = f_c, x_vars = x_vars, unique_flag = 1) %>%
  as.data.frame() %>%
  dplyr::select(a_i, c_i, pred_nipt, pred_invasive)

# altogether: observed columns next to the simulated ones (one draw per row)
all <- cbind(actual, simulated)

### simulate 100 draws per pregnancy
# Stack K copies of the sample in a single rbind() call; growing the data frame
# inside a loop re-copies all accumulated rows on every iteration.
K <- 100
if (K > 1) {
  actual <- do.call(rbind, replicate(K, actual, simplify = FALSE))
}

set.seed(101)
simulated_100 <- model_decisions(par = par_est, data = actual, f_a = f_a, f_c = f_c, x_vars = x_vars, unique_flag = 1) %>%
  as.data.frame() %>%
  dplyr::select(a_i, c_i, pred_nipt, pred_invasive)

all_100 <- cbind(actual, simulated_100)

# create X*betas
# The first six parameters are (mu_a, mu_c, se_a, se_c, rho, psi); the
# remainder is read as one coefficient per covariate for a, followed by one per
# covariate for c. Positions and column names follow x_vars, so they stay in
# sync if the covariate list changes.
n_x <- length(x_vars)
betas <- par_est[-(1:6)]
if (length(betas) != 2 * n_x) {
  warning(
    "Expected ", 2 * n_x, " X coefficients in par_est but found ",
    length(betas), ".",
    call. = FALSE
  )
}
a_beta <- betas[seq_len(n_x)]
c_beta <- betas[n_x + seq_len(n_x)]

# dplyr::select is spelled out because MASS, which is also attached, exports a
# select() too.
xs_only <- all %>%
  dplyr::select(all_of(x_vars))

# Column-by-column product: covariate j times its coefficient
x_beta_a <- data.frame(mapply("*", xs_only, a_beta))
colnames(x_beta_a) <- paste0(x_vars, "_a")
x_beta_c <- data.frame(mapply("*", xs_only, c_beta))
colnames(x_beta_c) <- paste0(x_vars, "_c")
x_beta <- cbind(x_beta_a, x_beta_c)

# Row-wise totals of the X*beta terms, added left to right in x_vars order
risk_plot_output <- cbind(all, x_beta)
risk_plot_output$sum_xbeta_a <- Reduce(`+`, x_beta_a)
risk_plot_output$sum_xbeta_c <- Reduce(`+`, x_beta_c)

rm(actual, simulated)
write.csv(all, file = paste0(TEMP, "/estimate_w_xs/model_predict_disagg_no_xs.csv"), na = ".", row.names = TRUE)
write.csv(all, file = paste0(RESULTS, "/estimate_w_xs/model_predict_disagg_no_xs.csv"), na = ".", row.names = TRUE)

write.csv(risk_plot_output, file = paste0(RESULTS, "/estimate_w_xs/risk_plot_no_xs.csv"), na = ".", row.names = TRUE)
write.csv(all_100, file = paste0(RESULTS, "/estimate_w_xs/risk_plot_100_no_xs.csv"), na = ".", row.names = TRUE)


#### EXIT ####
print("finished estimate_xs!")
# Close the log opened at the top of the script
sink()
#rm(list = ls())
